"""A HTTP API transport that offers optional local caching of the results."""

import datetime
import enum
import hashlib
import json
import os
import pathlib
import platform
import re
from contextlib import contextmanager
from importlib.metadata import version
from json import JSONDecodeError
from typing import Optional, Callable, Set, Union, Collection, Dict, Literal, Tuple
import shutil
import logging
from pathlib import Path

import pandas
import pandas as pd
import requests
from filelock import FileLock
from requests import Response
from requests.adapters import HTTPAdapter

from tradingstrategy.candle import TradingPairDataAvailability
from tradingstrategy.chain import ChainId
from tradingstrategy.exchange import ExchangeUniverse
from tradingstrategy.timebucket import TimeBucket
from tradingstrategy.transport.jsonl import load_candles_jsonl
from tradingstrategy.types import PrimaryKey
from tradingstrategy.lending import LendingCandle, LendingCandleType
from urllib3 import Retry

from tradingstrategy.utils.time import naive_utcfromtimestamp

logger = logging.getLogger(__name__)


class APIError(Exception):
    """Root of the exception hierarchy for failures reported by the Trading Strategy API."""


class DataNotAvailable(APIError):
    """Requested data does not exist on the dataset server (yet).

    Typical case: an entry has only just come online. It is already listed in
    the pair or reserve map, but no candles have been generated for it yet.

    Raised in place of a plain HTTP 404 reply from the dataset server.
    """


class CacheStatus(enum.Enum):
    """Outcome of a cache lookup, reported back to whoever is reading cached files."""

    #: File exists and is younger than the cache period
    cached = "cached"

    #: File exists and its name carries an explicit end time, so it never expires
    cached_with_timestamped_name = "cached_with_timestamped_name"

    #: File has not been written yet
    missing = "missing"

    #: File exists but is older than the cache period
    expired = "expired"

    def is_readable(self):
        """Tell whether a lookup with this status produced a file that may be read."""
        return self is CacheStatus.cached or self is CacheStatus.cached_with_timestamped_name


class CachedHTTPTransport:
    """A HTTP API transport that offers optional local caching of the results.

    - Download live and cached datasets from the candle server and cache locally
      on the filesystem

    - The download files are very large and expect to need several gigabytes of space for them

    - Has a default HTTP retry policy in the case network or server flakiness

    - Cached files expire after ``cache_period``, judged by the file modification time (mtime)

    - Writers of the same cache file are serialised with :py:func:`wait_other_writers`
    """

    #: Cache file names that embed an explicit ``-to_<YYYY-MM-DD_HH-MM-SS>`` end time
    #: describe a closed history window and are never treated as expired.
    #: NOTE(review): names produced by :py:meth:`_generate_cache_name` use
    #: ``-and-<end>`` rather than ``-to_<end>``, so they do not match this pattern
    #: and expire normally - confirm this is intended.
    _END_TIME_PATTERN = re.compile(r"-to_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}")

    def __init__(self,
                 download_func: Callable,
                 endpoint: Optional[str] = None,
                 cache_period: datetime.timedelta = datetime.timedelta(days=3),
                 cache_path: Optional[str] = None,
                 api_key: Optional[str] = None,
                 timeout: float = 15.0,
                 add_exception_hook=True,
                 retry_policy: Optional[Retry] = None):
        """
        :param download_func: Interactive download progress bar displayed during the download.
            Called as ``download_func(session, fpath, url, params, timeout, human_readable_hint)``.
        :param endpoint: API server we are using - default is `https://tradingstrategy.ai/api`
        :param cache_period: How long we keep the downloaded files before they expire
        :param cache_path: Where we store the downloaded files - default is `~/.cache/tradingstrategy`
        :param api_key: Trading Strategy API key to use download
        :param timeout: requests HTTP lib timeout, applied to every request this transport makes
        :param add_exception_hook: Automatically raise an error in the case of HTTP error. Prevents auto retries.
        :param retry_policy:

            How to handle failed HTTP requests.
            If not given use the default somewhat graceful retry policy.
        """
        self.download_func = download_func
        self.endpoint = endpoint or "https://tradingstrategy.ai/api"
        self.cache_period = cache_period
        self.cache_path = cache_path or os.path.expanduser("~/.cache/tradingstrategy")
        self.api_key = api_key
        self.timeout = timeout

        self.requests = self.create_requests_client(
            api_key=api_key,
            retry_policy=retry_policy,
            add_exception_hook=add_exception_hook,
        )

    def close(self):
        """Release any underlying sockets."""
        self.requests.close()

    def create_requests_client(self,
                               retry_policy: Optional[Retry]=None,
                               api_key: Optional[str] = None,
                               add_exception_hook=True):
        """Create HTTP 1.1 keep-alive connection to the server with optional authorization details.

        :param retry_policy:
            urllib3 retry policy mounted for both ``http://`` and ``https://``.
            If not given, use ``Retry(total=5, backoff_factor=0.1)`` that also
            retries on 500, 502, 503 and 504 replies.

        :param api_key:
            If given, sent as the ``Authorization`` header.

        :param add_exception_hook:
            Automatically raise an error in the case of HTTP error:
            :py:class:`DataNotAvailable` for 404, :py:class:`APIError` for any other status >= 400.

        :return:
            Configured :py:class:`requests.Session`
        """
        session = requests.Session()

        # Set up dealing with network connectivity flakey
        if retry_policy is None:
            # https://stackoverflow.com/a/35504626/315168
            retry_policy = Retry(
                total=5,
                backoff_factor=0.1,
                status_forcelist=[500, 502, 503, 504],
            )

        # Mount the retry policy whether it is the default one or caller-supplied
        session.mount('http://', HTTPAdapter(max_retries=retry_policy))
        session.mount('https://', HTTPAdapter(max_retries=retry_policy))

        if api_key:
            session.headers.update({'Authorization': api_key})

        # Identify the client library version and platform to the server
        package_version = version("trading-strategy")
        system = platform.system()
        release = platform.release()
        session.headers.update({"User-Agent": f"trading-strategy {package_version} on {system} {release}"})

        if add_exception_hook:
            def exception_hook(response: Response, *args, **kwargs):
                if response.status_code == 404:
                    raise DataNotAvailable(f"Server error reply: code:{response.status_code} message:{response.text}")
                elif response.status_code >= 400:
                    raise APIError(f"Server error reply: code:{response.status_code} message:{response.text}")

            session.hooks = {
                "response": exception_hook,
            }
        return session

    def get_abs_cache_path(self):
        """Return the cache directory as an absolute path."""
        return os.path.abspath(self.cache_path)

    def get_cached_file_path(self, fname):
        """Return the absolute path of ``fname`` inside the cache directory.

        If ``fname`` is already absolute, :py:func:`os.path.join` returns it unchanged.
        """
        return os.path.join(self.get_abs_cache_path(), fname)

    def get_cached_item(self, fname: Union[str, pathlib.Path]) -> Optional[pathlib.Path]:
        """Get a cached file.

        - Return ``None`` if the cache has expired or the file is missing

        - The cache timeout is coded in the file modified
          timestamp (mtime)

        See :py:meth:`get_cached_item_with_status` for a variant that also
        reports why the file is (not) available.
        """
        cached, _status = self.get_cached_item_with_status(fname)
        return cached

    def get_cached_item_with_status(
            self,
            fname: Union[str, pathlib.Path]
    ) -> Tuple[pathlib.Path | None, CacheStatus]:
        """Get a cached file together with the cache lookup outcome.

        - Return ``(None, CacheStatus.missing)`` if the file has not been written

        - Return ``(None, CacheStatus.expired)`` if the cache has expired

        - The cache timeout is coded in the file modified
          timestamp (mtime)

        - Files whose name embeds an explicit end time never expire
        """
        path = self.get_cached_file_path(fname)
        if not os.path.exists(path):
            # Cached item not yet created
            return None, CacheStatus.missing

        f = pathlib.Path(path)

        # For some datasets, we encode the end time in the fname
        if self._END_TIME_PATTERN.search(str(fname)):
            # Candle files with an end time never expire, as the history does not change
            return f, CacheStatus.cached_with_timestamped_name

        mtime = datetime.datetime.fromtimestamp(f.stat().st_mtime)
        if datetime.datetime.now() - mtime > self.cache_period:
            # File cache expired
            return None, CacheStatus.expired

        return f, CacheStatus.cached

    def _generate_cache_name(
        self,
        pair_ids: Union[Collection[PrimaryKey], PrimaryKey],
        time_bucket: TimeBucket,
        start_time: Optional[datetime.datetime] = None,
        end_time: Optional[datetime.datetime] = None,
        max_bytes: Optional[int] = None,
        candle_type: str = "candles",
    ) -> str:
        """Generate the name of the file for holding cached candle data for ``pair_ids``.

        :param pair_ids:
            Pair ids, or a single reserve id for lending candles.
            Its ``str()`` form goes into the cache key.

        :return:
            File name (not a path) that encodes the query parameters
        """
        # Meaningfully truncate the timestamp to align with the target time bucket,
        # so that repeated queries with a "now"-like end time hit the same cache file.
        # Microseconds must go too, as they are part of str(end_time) in the cache key.
        if end_time:
            candle_width = time_bucket.to_timedelta()
            trunc = {"second": 0, "microsecond": 0}
            if candle_width >= datetime.timedelta(hours=1):
                trunc["minute"] = 0
            if candle_width >= datetime.timedelta(days=1):
                trunc["hour"] = 0
            end_time = end_time.replace(**trunc)

        # Create a compressed cache key for the filename,
        # as we have 256 char limit on fname lengths
        full_cache_key = (
            f"{pair_ids}{time_bucket}{start_time}{end_time}{max_bytes}"
        )
        md5 = hashlib.md5(full_cache_key.encode("utf-8")).hexdigest()

        # Human readable start/end parts make the cache directory easier to inspect
        start_part = start_time.strftime("%Y-%m-%d_%H-%M-%S") if start_time else "any"
        end_part = end_time.strftime("%Y-%m-%d_%H-%M-%S") if end_time else "any"

        return f"{candle_type.replace('_', '-')}-jsonl-{time_bucket.value}-between-{start_part}-and-{end_part}-{md5}.parquet"

    def purge_cache(self, filename: Optional[Union[str, pathlib.Path]] = None):
        """Delete all cached files on the filesystem.

        :param filename:
            If given, remove only that specific file (or directory tree), otherwise clear all cached data.
            NOTE(review): a relative ``filename`` is resolved against the current
            working directory, not the cache directory.
        """
        target_path = self.cache_path if filename is None else filename

        logger.info("Purging caches at %s", target_path)
        try:
            if os.path.isdir(target_path):
                shutil.rmtree(target_path)
            else:
                os.remove(target_path)
        except FileNotFoundError as exc:
            logger.warning(
                "Attempted to purge caches, but no such file or directory: %s",
                exc.filename,
            )

    def save_response(self, fpath, api_path, params=None, human_readable_hint: Optional[str]=None):
        """Download a file to the cache and display a pretty progress bar while doing it.

        :param fpath:
            File system path where the download will be saved

        :param api_path:
            Which Trading Strategy backtesting API we call to download the dataset.

        :param params:
            HTTP request query parameters, like ``bucket``

        :param human_readable_hint:
            The status text displayed on the progress bar what's being downloaded
        """
        os.makedirs(self.get_abs_cache_path(), exist_ok=True)
        url = f"{self.endpoint}/{api_path}"
        logger.debug("Saving %s to %s", url, fpath)
        # https://stackoverflow.com/a/14114741/315168
        self.download_func(self.requests, fpath, url, params, self.timeout, human_readable_hint)

    def get_json_response(self, api_path, params=None):
        """GET ``{endpoint}/{api_path}`` and decode the JSON reply. Not cached.

        Uses ``self.timeout`` so a stalled server cannot block the caller indefinitely.
        """
        url = f"{self.endpoint}/{api_path}"
        response = self.requests.get(url, params=params, timeout=self.timeout)
        return response.json()

    def post_json_response(self, api_path, params=None):
        """POST to ``{endpoint}/{api_path}`` with query ``params`` and decode the JSON reply. Not cached."""
        url = f"{self.endpoint}/{api_path}"
        response = self.requests.post(url, params=params, timeout=self.timeout)
        return response.json()

    def fetch_chain_status(self, chain_id: int) -> dict:
        """Not cached."""
        return self.get_json_response("chain-status", params={"chain_id": chain_id})

    def fetch_pair_universe(self) -> pathlib.Path:
        """Download the trading pair dataset, or return the cached copy.

        :return:
            Path to ``pair-universe.parquet`` in the cache directory
        """
        fname = "pair-universe.parquet"
        path = self.get_cached_file_path(fname)

        with wait_other_writers(path):
            # Check the cache only after we hold the lock: another writer
            # may have completed the download while we were waiting.
            cached = self.get_cached_item(fname)
            if cached:
                return cached

            self.save_response(path, "pair-universe", human_readable_hint="Downloading trading pair dataset")
            return self.get_cached_item(fname)

    def fetch_exchange_universe(self) -> pathlib.Path:
        """Download the exchange dataset, or return the cached copy.

        :raise RuntimeError:
            If the downloaded file is not valid JSON or is an error reply
            (the broken file is deleted).
        """
        fname = "exchange-universe.json"
        path = self.get_cached_file_path(fname)

        with wait_other_writers(path):

            cached = self.get_cached_item(fname)
            if cached:
                return cached

            self.save_response(path, "exchange-universe", human_readable_hint="Downloading exchange dataset")

            _check_good_json(path, "fetch_exchange_universe() failed")

            return self.get_cached_item(fname)

    def fetch_lending_reserve_universe(self) -> pathlib.Path:
        """Download the lending reserve dataset, or return the cached copy.

        :raise RuntimeError:
            If the downloaded file is not valid JSON or is an error reply
            (the broken file is deleted).
        """
        fname = "lending-reserve-universe.json"
        path = self.get_cached_file_path(fname)

        with wait_other_writers(path):
            # Check the cache only after we hold the lock, like the sibling fetchers
            cached = self.get_cached_item(fname)
            if cached:
                return cached

            self.save_response(
                path,
                "lending-reserve-universe",
                human_readable_hint="Downloading lending reserve dataset"
            )
            saved, status = self.get_cached_item_with_status(fname)

            # Verify we got a file before trying to parse it
            assert status.is_readable(), f"Got status {status} for {fname}"

            _check_good_json(saved, "fetch_lending_reserve_universe() failed")

            return saved

    def fetch_candles_all_time(self, bucket: TimeBucket) -> pathlib.Path:
        """Load candles and return a cached file where they are stored.

        - If cached file exists return it directly

        - Wait if someone else is writing the file
          (in multiple parallel testers)
        """
        assert isinstance(bucket, TimeBucket)
        fname = f"candles-{bucket.value}.parquet"
        cached_path = self.get_cached_file_path(fname)

        with wait_other_writers(cached_path):

            cached = self.get_cached_item(fname)
            if cached:
                # Cache exists and is not expired
                return cached

            # Download save the file
            params = {"bucket": bucket.value}
            self.save_response(cached_path, "candles-all", params, human_readable_hint=f"Downloading OHLCV data for {bucket.value} time bucket")
            logger.info(
                "Saved %s with params %s",
                cached_path,
                params
            )
            saved, status = self.get_cached_item_with_status(fname)
            # Troubleshoot multiple test workers race condition
            assert status.is_readable(), f"Cache status {status} with save_response() generated for {fname}, cached path is {cached_path}, download_func is {self.download_func}"
            return saved

    def fetch_liquidity_all_time(self, bucket: TimeBucket) -> pathlib.Path:
        """Download all-time liquidity samples for ``bucket``, or return the cached copy."""
        fname = f"liquidity-samples-{bucket.value}.parquet"
        path = self.get_cached_file_path(fname)

        with wait_other_writers(path):

            cached = self.get_cached_item(fname)
            if cached:
                return cached

            # Download save the file
            self.save_response(path, "liquidity-all", params={"bucket": bucket.value}, human_readable_hint=f"Downloading liquidity data for {bucket.value} time bucket")
            return self.get_cached_item(fname)

    def fetch_lending_reserves_all_time(self) -> pathlib.Path:
        """Download all-time lending reserve data, or return the cached copy."""
        fname = "lending-reserves-all.parquet"
        path = self.get_cached_file_path(fname)

        with wait_other_writers(path):

            cached = self.get_cached_item(fname)
            if cached:
                return cached

            # We only have Aave v3 data for now...
            self.save_response(
                path,
                "aave-v3-all",
                human_readable_hint="Downloading Aave v3 reserve dataset",
            )
            assert os.path.exists(path)
            item, status = self.get_cached_item_with_status(fname)
            assert status.is_readable(), f"File not readable after save cached:{cached} fname:{fname} path:{path}"
            return item

    def fetch_lending_candles_by_reserve_id(
        self,
        reserve_id: int,
        time_bucket: TimeBucket,
        candle_type: LendingCandleType = LendingCandleType.variable_borrow_apr,
        start_time: Optional[datetime.datetime] = None,
        end_time: Optional[datetime.datetime] = None,
    ) -> pd.DataFrame:
        """Load particular set of the lending candles and cache the result.

        For the candles format see :py:mod:`tradingstrategy.lending`.

        :param reserve_id:
            Lending reserve's internal id we query data for.
            Get internal id from lending reserve universe dataset.

        :param time_bucket:
            Candle time frame.

        :param candle_type:
            Lending candle type.

        :param start_time:
            All candles after this.
            If not given start from genesis.

        :param end_time:
            All candles before this

        :return:
            Lending candles dataframe

        :raise DataNotAvailable:
            Server replied 404 for the query

        :raise APIError:
            Any other failure while performing the request
        """
        assert isinstance(time_bucket, TimeBucket)
        assert isinstance(candle_type, LendingCandleType)

        cache_fname = self._generate_cache_name(
            reserve_id,
            time_bucket,
            start_time,
            end_time,
            candle_type=candle_type.name,
        )

        full_fname = self.get_cached_file_path(cache_fname)

        with wait_other_writers(full_fname):

            cached = self.get_cached_item(cache_fname)

            if cached:
                logger.debug("Using cached data file %s", full_fname)
                return pandas.read_parquet(cached)

            api_url = f"{self.endpoint}/lending-reserve/candles"

            params = {
                "reserve_id": reserve_id,
                "time_bucket": time_bucket.value,
                "candle_types": candle_type,
            }

            if start_time:
                params["start"] = start_time.isoformat()

            if end_time:
                params["end"] = end_time.isoformat()

            try:
                resp = self.requests.get(api_url, params=params, stream=True, timeout=self.timeout)
            except DataNotAvailable as e:
                # We have special request hook that translates 404 to this exception
                raise DataNotAvailable(f"Could not fetch lending candles for {params}") from e
            except Exception as e:
                raise APIError(f"Could not fetch lending candles for {params}") from e

            # TODO: handle error
            candles = resp.json()[candle_type]

            df = LendingCandle.convert_web_candles_to_dataframe(candles)

            # Update cache
            df.to_parquet(full_fname)

            size = pathlib.Path(full_fname).stat().st_size
            logger.debug(f"Wrote {cache_fname}, disk size is {size:,}b")

            return df

    def ping(self) -> dict:
        """Call the ``ping`` endpoint. Not cached."""
        return self.get_json_response("ping")

    def message_of_the_day(self) -> dict:
        """Fetch the ``message-of-the-day`` endpoint. Not cached."""
        return self.get_json_response("message-of-the-day")

    def register(self, first_name, last_name, email) -> dict:
        """Makes a register request.

        The request does not load any useful payload, but it is assumed the email message gets verified
        and the user gets the API from the email.
        """
        return self.post_json_response("register", params={"first_name": first_name, "last_name": last_name, "email": email})

    def fetch_candles_by_pair_ids(
            self,
            pair_ids: Collection[PrimaryKey],
            time_bucket: TimeBucket,
            start_time: Optional[datetime.datetime] = None,
            end_time: Optional[datetime.datetime] = None,
            max_bytes: Optional[int] = None,
            progress_bar_description: Optional[str] = None,
    ) -> pd.DataFrame:
        """Load particular set of the candles and cache the result.

        If there is no cached result, load using JSONL.

        More information in :py:mod:`tradingstrategy.transport.jsonl`.

        For the candles format see :py:mod:`tradingstrategy.candle`.

        :param pair_ids:
            Trading pairs internal ids we query data for.
            Get internal ids from pair dataset.

        :param time_bucket:
            Candle time frame

        :param start_time:
            All candles after this.
            If not given start from genesis.

        :param end_time:
            All candles before this

        :param max_bytes:
            Limit the streaming response size

        :param progress_bar_description:
            Display on download progress bar

        :return:
            Candles dataframe
        """
        cache_fname = self._generate_cache_name(
            pair_ids, time_bucket, start_time, end_time, max_bytes
        )

        full_fname = self.get_cached_file_path(cache_fname)

        with wait_other_writers(full_fname):

            cached = self.get_cached_item(cache_fname)

            if cached:
                logger.debug("Using cached JSONL data file %s", full_fname)
                return pandas.read_parquet(cached)

            df: pd.DataFrame = load_candles_jsonl(
                self.requests,
                self.endpoint,
                pair_ids,
                time_bucket,
                start_time,
                end_time,
                max_bytes=max_bytes,
                progress_bar_description=progress_bar_description,
            )

            # Update cache
            df.to_parquet(full_fname)

            size = pathlib.Path(full_fname).stat().st_size
            logger.debug(f"Wrote {cache_fname}, disk size is {size:,}b")

            return df

    def fetch_trading_data_availability(self,
          pair_ids: Collection[PrimaryKey],
          time_bucket: TimeBucket,
        ) -> Dict[PrimaryKey, TradingPairDataAvailability]:
        """Check the trading data availability at oracle's real time market feed endpoint.

        - Trading Strategy oracle uses sparse data format where candles
          with zero trades are not generated. This is better suited
          for illiquid DEX markets with few trades.

        - Because of sparse data format, we do not know if there is a last
          candle available - candle may not be available yet or there might not be trades
          to generate a candle

        - This endpoint allows to check the trading data availability for multiple of trading pairs.

        - This endpoint is public

        :param pair_ids:
            Trading pairs internal ids we query data for.
            Get internal ids from pair dataset.

        :param time_bucket:
            Candle time frame

        :return:
            Map of pairs -> their trading data availability

        :raise RuntimeError:
            If an entry in the server reply cannot be deserialised
        """
        params = {
            "pair_ids": ",".join([str(i) for i in pair_ids]),  # OpenAPI comma delimited array
            "time_bucket": time_bucket.value,
        }

        array = self.get_json_response("trading-pair-data-availability", params=params)

        # Convert JSON entries to typed values
        def _convert(p: dict) -> TradingPairDataAvailability:
            try:
                return {
                    "chain_id": ChainId(p["chain_id"]),
                    "pair_id": p["pair_id"],
                    "pair_address": p["pair_address"],
                    "last_trade_at": datetime.datetime.fromisoformat(p["last_trade_at"]),
                    "last_candle_at": datetime.datetime.fromisoformat(p["last_candle_at"]),
                    "last_supposed_candle_at": datetime.datetime.fromisoformat(p["last_supposed_candle_at"]),
                }
            except Exception as e:
                raise RuntimeError(f"Failed to convert: {p}") from e

        return {p["pair_id"]: _convert(p) for p in array}


@contextmanager
def wait_other_writers(path: Path | str, timeout=120):
    """Wait other potential writers writing the same file.

    - Work around issues when parallel unit tests and such
      try to write the same file

    - The lock is a sibling file named ``<path>.lock``

    - If another writer already holds the lock, an info message is logged
      before we start waiting

    Example:

    .. code-block:: python

        import os
        import urllib.request
        import tempfile

        import pytest
        import pandas as pd

        @pytest.fixture()
        def my_cached_test_data_frame() -> pd.DataFrame:

            # All tests use a cached dataset stored in the /tmp directory
            path = os.path.join(tempfile.gettempdir(), "my_shared_data.parquet")

            with wait_other_writers(path):

                # Read result from the previous writer
                if not os.path.exists(path):
                    # Download and write to cache
                    urllib.request.urlretrieve("https://example.com", path)

                return pd.read_parquet(path)

    :param path:
        File that is being written. Must be absolute.

    :param timeout:
        How many seconds wait to acquire the lock file.

        Default 2 minutes.

    :raise filelock.Timeout:
        If the file writer is stuck with the lock.
    """
    # Local import: only needed to detect a failed non-blocking acquire
    from filelock import Timeout

    if isinstance(path, str):
        path = Path(path)

    assert isinstance(path, Path), f"Not Path object: {path}"

    assert path.is_absolute(), f"Did not get an absolute path: {path}\n" \
                               f"Please use absolute paths for lock files to prevent polluting the local working directory."

    # If we are writing to a new temp folder, create any parent paths
    os.makedirs(path.parent, exist_ok=True)

    # https://stackoverflow.com/a/60281933/315168
    lock_file = path.parent / (path.name + '.lock')

    lock = FileLock(lock_file, timeout=timeout)

    # FileLock.is_locked only tells whether *this* object holds the lock,
    # so it cannot detect another writer. Probe with a non-blocking attempt instead.
    try:
        lock.acquire(timeout=0)
    except Timeout:
        logger.info(
            "Parquet file %s locked for writing, waiting %f seconds",
            path,
            timeout,
        )
        # Blocking acquire with the constructor timeout; raises filelock.Timeout if stuck
        lock.acquire()

    try:
        yield
    finally:
        lock.release()


def _check_good_json(path: Path, exception_message: str):
    """Check that server gave us good JSON file.

    Guards against error replies (404, 500, API key errors) that ended up
    in a downloaded file instead of the dataset.

    - A file that is not valid JSON is broken

    - A JSON object with an ``"error"`` key, or a JSON array/string containing
      ``"error"``, is broken

    - JSON numbers, booleans and ``null`` are not treated as error replies

    Broken files are deleted, so they are not picked up from the cache later.

    :param path:
        Downloaded JSON file

    :param exception_message:
        First line of the raised error message

    :raise RuntimeError:
        If the file is broken. The message includes the file content.
    """
    # Quick fix to avoid getting hit by API key errors here.
    # TODO: Clean this up properly
    with open(path, "rt", encoding="utf-8") as inp:
        data = inp.read()

    try:
        data = json.loads(data)
    except JSONDecodeError:
        broken_data = True
    else:
        # `in` checks dict keys, list items or substrings, but raises TypeError
        # for JSON scalars (numbers, booleans, null), so guard the type first.
        broken_data = isinstance(data, (dict, list, str)) and "error" in data

    if broken_data:
        os.remove(path)
        raise RuntimeError(f"{exception_message}\nJSON data is: {data}")

